{
 "metadata": {
  "name": "ch3_nltk"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "# Building a Naive Bayes spam classifier with NLTK\n",
      "\n",
      "We'll follow the same logic as the program from chapter 3 of *Machine Learning for Hackers*, but we'll do so with a workflow more suited to NLTK's functions. So instead of creating a term-document matrix, and building our own Naive Bayes classifier, we'll build a `features` $\\rightarrow$ `label` association for each training e-mail, and feed a list of these to NLTK's `NaiveBayesClassifier` function.\n",
      "\n",
      "Some good references for this are:\n",
      "\n",
      "Bird, Steven and et. al., *Natural Language Processing with Python*\n",
      "\n",
      "Perkins, Jacob, *Python Text Processing with NLTK 2.0 Cookbook*\n",
      "\n",
      "\n"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from pandas import *\n",
      "import numpy as np\n",
      "import os\n",
      "import re\n",
      "from nltk import NaiveBayesClassifier\n",
      "import nltk.classify\n",
      "from nltk.tokenize import wordpunct_tokenize\n",
      "from nltk.corpus import stopwords\n",
      "from collections import defaultdict"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 1
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "## Loading the e-mail messages into lists\n",
      "\n",
      "E-mails of each type --spam, \"easy\" ham, and \"hard\" ham-- are split across two directories per type. We'll use the first directories of spam and \"easy\" ham to train the classifier. Then we'll test the classifier on the e-mails in the second directories."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "data_path = os.path.abspath(os.path.join('.', 'data'))\n",
      "spam_path = os.path.join(data_path, 'spam')\n",
      "spam2_path = os.path.join(data_path, 'spam_2') \n",
      "easyham_path = os.path.join(data_path, 'easy_ham')\n",
      "easyham2_path = os.path.join(data_path, 'easy_ham_2')\n",
      "hardham_path = os.path.join(data_path, 'hard_ham')\n",
      "hardham2_path = os.path.join(data_path, 'hard_ham_2')"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 2
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "The following function loads all the e-mail files in a directory, extracts their message bodies and returns them in a list."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def get_msgdir(path):\n",
      "    '''\n",
      "    Read all messages from files in a directory into\n",
      "    a list where each item is the text of a message. \n",
      "    \n",
      "    Simply gets a list of e-mail files in a directory,\n",
      "    and iterates get_msg() over them.\n",
      "\n",
      "    Returns a list of strings.\n",
      "    '''\n",
      "    filelist = os.listdir(path)\n",
      "    filelist = filter(lambda x: x != 'cmds', filelist)\n",
      "    all_msgs =[get_msg(os.path.join(path, f)) for f in filelist]\n",
      "    return all_msgs\n",
      "\n",
      "def get_msg(path):\n",
      "    '''\n",
      "    Read in the 'message' portion of an e-mail, given\n",
      "    its file path. The 'message' text begins after the first\n",
      "    blank line; above is header information.\n",
      "\n",
      "    Returns a string.\n",
      "    '''\n",
      "    with open(path, 'rU') as con:\n",
      "        msg = con.readlines()\n",
      "        first_blank_index = msg.index('\\n')\n",
      "        msg = msg[(first_blank_index + 1): ]\n",
      "        return ''.join(msg) "
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 3
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "We'll use the functions to make training and testing message lists for each type of e-mail."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "train_spam_messages    = get_msgdir(spam_path)\n",
      "train_easyham_messages = get_msgdir(easyham_path)\n",
      "# Only keep the first 500 to balance w/ number of spam messages.\n",
      "train_easyham_messages = train_easyham_messages[:500]\n",
      "train_hardham_messages = get_msgdir(hardham_path)\n",
      "\n",
      "test_spam_messages    = get_msgdir(spam2_path)\n",
      "test_easyham_messages = get_msgdir(easyham2_path)\n",
      "test_hardham_messages = get_msgdir(hardham2_path)\n"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 4
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "## Extracting word features from the e-mail messages\n",
      "\n",
      "Each e-mail in our classifier's training data will have a label (\"spam\" or \"ham\") and a feature set. For this application, we're just going to use a feature set that is just a set of the unique words in the e-mail. Below, we'll turn this into a dictionary to feed into the `NaiveBayesClassifier`, but first, let's get the set."
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "### Parsing and tokenizing the e-mails\n",
      "\n",
      "We're going to use NLTK's `wordpunct_tokenize` function to break the message into tokens. This splits tokens at white space and (most) punctuation marks, and returns the punctuation along with the tokens on each side. So `\"I don't know. Do you?\"` becomes `[\"I\", \"don\", \"'\", \"t\", \"know\", \".\", \"Do\", \"you\", \"?\"]`.\n",
      "\n",
      "If you look through some of the training e-mails in `train_spam_messages` and `train_ham_messages`, you'll notice a few features that make extracting words tricky.\n",
      "\n",
      "First, there are a couple of odd text artefacts. The string '3D' shows up in strange places in HTML attributes and other places, and we'll remove these. Furthermore there seem to be some mid-word line wraps flagged with an '=' where the word is broken across lines. For example, the work 'apple' might be split across lines like 'app=\\nle'.  We want to strip these out so we can recover 'apple'. We'll want to deal with all these first, before we apply the tokenizer.\n",
      "\n",
      "Second, there's a lot of HTML in the messages. We'll have to decide first whether we want to keep HTML info in our set of words. If we do, we'll apply `wordpunct_tokenize` to some HTML, for example:\n",
      "\n",
      "`\"<HEAD></HEAD><BODY><!-- Comment -->\"`\n",
      "\n",
      "and it will tokenize to:\n",
      "\n",
      "`[\"<\", \"HEAD\", \"></\", \"HEAD\", \"><\", \"BODY\", \"><!--\", \"Comment\", \"-->\"]`\n",
      "\n",
      "So if we drop the punctuation tokens, and get the unique set of what remains, we'd have `{\"HEAD\", \"BODY\", \"Comment\"}`, which seems like what we'd want. For example, it's nice that this method doesn't make, `<HEAD>` and `</HEAD>` separate words in our set, but just captures the existence of this tag with the term `\"HEAD\"`. It might be a problem that we won't distinguish between the HTML tag `<HEAD>` and \"head\" used as an English word in the message. But for the moment I'm willing to bet that sort of conflation won't have a big affect on the classifier.\n",
      "\n",
      "If we don't want to count HTML information in our set of words, we can set the `strip_html` to `True`, and we'll take all the HTML tags out before tokenizing.\n",
      "\n",
      "Lastly we'll strip out any \"stopwords\" from the set. Stopwords are highly common, therefore low information words, like \"a\", \"the\", \"he\", etc. Below I'll use `stopwords`, downloaded from NLTK's corpus library, with a minor modifications to deal with this. (In other programs I've used the stopwords exported from R's `tm` package.)\n",
      "\n",
      "Note that because our tokenizer splits contractions (\"she'll\" $\\rightarrow$ \"she\", \"ll\"), we'd like to drop the ends (\"ll\"). Some of these may be picked up in NLTK's `stopwords` list, others we'll manually add. It's an imperfect, but easy solution. There are more sophisticated ways of dealing with this which are overkill for our purposes.\n",
      "\n",
      "Tokenizing, as perhaps you can tell, is a non-trivial operation. NLTK has a host of other tokenizing functions of varying sophistication, and even lets you define your own tokenizing rule using regex."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def get_msg_words(msg, stopwords = [], strip_html = False):\n",
      "    '''\n",
      "    Returns the set of unique words contained in an e-mail message. Excludes \n",
      "    any that are in an optionally-provided list. \n",
      "\n",
      "    NLTK's 'wordpunct' tokenizer is used, and this will break contractions.\n",
      "    For example, don't -> (don, ', t). Therefore, it's advisable to supply\n",
      "    a stopwords list that includes contraction parts, like 'don' and 't'.\n",
      "    '''\n",
      "    \n",
      "    # Strip out weird '3D' artefacts.\n",
      "    msg = re.sub('3D', '', msg)\n",
      "    \n",
      "    # Strip out html tags and attributes and html character codes,\n",
      "    # like &nbsp; and &lt;.\n",
      "    if strip_html:\n",
      "        msg = re.sub('<(.|\\n)*?>', ' ', msg)\n",
      "        msg = re.sub('&\\w+;', ' ', msg)\n",
      "    \n",
      "    # wordpunct_tokenize doesn't split on underscores. We don't\n",
      "    # want to strip them, since the token first_name may be informative\n",
      "    # moreso than 'first' and 'name' apart. But there are tokens with long\n",
      "    # underscore strings (e.g. 'name_________'). We'll just replace the\n",
      "    # multiple underscores with a single one, since 'name_____' is probably\n",
      "    # not distinct from 'name___' or 'name_' in identifying spam.\n",
      "    msg = re.sub('_+', '_', msg)\n",
      "\n",
      "    # Note, remove '=' symbols before tokenizing, since these are\n",
      "    # sometimes occur within words to indicate, e.g., line-wrapping.\n",
      "    msg_words = set(wordpunct_tokenize(msg.replace('=\\n', '').lower()))\n",
      "     \n",
      "    # Get rid of stopwords\n",
      "    msg_words = msg_words.difference(stopwords)\n",
      "    \n",
      "    # Get rid of punctuation tokens, numbers, and single letters.\n",
      "    msg_words = [w for w in msg_words if re.search('[a-zA-Z]', w) and len(w) > 1]\n",
      "    \n",
      "    return msg_words"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 5
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "The stopwords list. While it contains some terms to account for contractions, we'll add a couple more."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "sw = stopwords.words('english')\n",
      "sw.extend(['ll', 've'])"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 6
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "### Making a `(features, label)` list\n",
      "\n",
      "The `NaiveBayesClassifier` function trains on data that's of the form `[(features1, label1), features2, label2), ..., (featuresN, labelN)]` where `featuresi` is a dictionary of features for e-mail `i` and `labeli` is the label for e-mail `i' (\"spam\" or \"ham\"). \n",
      "\n",
      "The function `features_from_messages` iterates through the messages creating this list, but calls an outside function to create the `features` for each e-mail. This makes the function modular in case we decide to try out some other method of extracting features from the e-mails besides the set of word. It then combines the features to the e-mail's label in a tuple and adds the tuple to the list.\n",
      "\n",
      "The `word_indicator` function calls `get_msg_words()` to get an e-mail's words as a set, then creates a dictionary with entries `{word: True}` for each word in the set. This is a little counter-intuitive (since we don't have `{word: False}` entries for words not in the set) but `NaiveBayesClassifier` knows how to handle it."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def features_from_messages(messages, label, feature_extractor, **kwargs):\n",
      "    '''\n",
      "    Make a (features, label) tuple for each message in a list of a certain,\n",
      "    label of e-mails ('spam', 'ham') and return a list of these tuples.\n",
      "\n",
      "    Note every e-mail in 'messages' should have the same label.\n",
      "    '''\n",
      "    features_labels = []\n",
      "    for msg in messages:\n",
      "        features = feature_extractor(msg, **kwargs)\n",
      "        features_labels.append((features, label))\n",
      "    return features_labels\n",
      "\n",
      "def word_indicator(msg, **kwargs):\n",
      "    '''\n",
      "    Create a dictionary of entries {word: True} for every unique\n",
      "    word in a message.\n",
      "\n",
      "    Note **kwargs are options to the word-set creator,\n",
      "    get_msg_words().\n",
      "    '''\n",
      "    features = defaultdict(list)\n",
      "    msg_words = get_msg_words(msg, **kwargs)\n",
      "    for  w in msg_words:\n",
      "            features[w] = True\n",
      "    return features"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 7
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "## Training and evaluating the classifier"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "The following is just a helper function to make training and testing data from the messages. Notice we combine the training spam and training ham into a single set, since we need to train our classifier on data with both spam and ham in it."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def make_train_test_sets(feature_extractor, **kwargs):\n",
      "    '''\n",
      "    Make (feature, label) lists for each of the training \n",
      "    and testing lists.\n",
      "    '''\n",
      "    train_spam = features_from_messages(train_spam_messages, 'spam', \n",
      "                                        feature_extractor, **kwargs)\n",
      "    train_ham = features_from_messages(train_easyham_messages, 'ham', \n",
      "                                       feature_extractor, **kwargs)\n",
      "    train_set = train_spam + train_ham\n",
      "\n",
      "    test_spam = features_from_messages(test_spam_messages, 'spam',\n",
      "                                       feature_extractor, **kwargs)\n",
      "\n",
      "    test_ham = features_from_messages(test_easyham_messages, 'ham',\n",
      "                                      feature_extractor, **kwargs)\n",
      "\n",
      "    test_hardham = features_from_messages(test_hardham_messages, 'ham',\n",
      "                                          feature_extractor, **kwargs)\n",
      "    \n",
      "    return train_set, test_spam, test_ham, test_hardham"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 8
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Finally we make a function to run the classifier and check its accuracy on test data. After training the classifier, we check how accurately it classifies data in new spam, \"easy\" ham, and \"hard\" ham datasets. \n",
      "\n",
      "The function then prints out the results of `NaiveBayesClassifiers`'s handy `show_most_informative_features` method. This shows which features are most unique to one label or another. For example, if \"viagra\" shows up in 500 of the spam e-mails, but only 2 of the \"ham\" e-mails in the training set, then the method will show that \"viagra\" is one of the most informative features with a `spam:ham` ratio of 250:1."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def check_classifier(feature_extractor, **kwargs):\n",
      "    '''\n",
      "    Train the classifier on the training spam and ham, then check its accuracy\n",
      "    on the test data, and show the classifier's most informative features.\n",
      "    '''\n",
      "    \n",
      "    # Make training and testing sets of (features, label) data\n",
      "    train_set, test_spam, test_ham, test_hardham = \\\n",
      "        make_train_test_sets(feature_extractor, **kwargs)\n",
      "    \n",
      "    # Train the classifier on the training set\n",
      "    classifier = NaiveBayesClassifier.train(train_set)\n",
      "    \n",
      "    # How accurate is the classifier on the test sets?\n",
      "    print ('Test Spam accuracy: {0:.2f}%'\n",
      "       .format(100 * nltk.classify.accuracy(classifier, test_spam)))\n",
      "    print ('Test Ham accuracy: {0:.2f}%'\n",
      "       .format(100 * nltk.classify.accuracy(classifier, test_ham)))\n",
      "    print ('Test Hard Ham accuracy: {0:.2f}%'\n",
      "       .format(100 * nltk.classify.accuracy(classifier, test_hardham)))\n",
      "\n",
      "    # Show the top 20 informative features\n",
      "    print classifier.show_most_informative_features(20)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 9
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "First, we run the classifier keeping all the HTML information in the feature set. The accuracy at identifying spam and ham is very high. Unsurprisingly, we do a lousy job at identifying hard ham. \n",
      "\n",
      "This may be because our training set is relying too much on HTML tags to identify spam. As we can see, HTML info comprises all the `most_informative_features`."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "check_classifier(word_indicator, stopwords = sw)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "Test Spam accuracy: 98.71%\n",
        "Test Ham accuracy: 97.07%"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "Test Hard Ham accuracy: 13.71%"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "Most Informative Features\n",
        "                   align = True             spam : ham    =    119.7 : 1.0"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "                      tr = True             spam : ham    =    115.7 : 1.0\n",
        "                      td = True             spam : ham    =    111.7 : 1.0\n",
        "                   arial = True             spam : ham    =    107.7 : 1.0\n",
        "             cellpadding = True             spam : ham    =     97.0 : 1.0\n",
        "             cellspacing = True             spam : ham    =     94.3 : 1.0\n",
        "                     img = True             spam : ham    =     80.3 : 1.0\n",
        "                 bgcolor = True             spam : ham    =     67.4 : 1.0\n",
        "                    href = True             spam : ham    =     67.0 : 1.0\n",
        "                    sans = True             spam : ham    =     62.3 : 1.0\n",
        "                 colspan = True             spam : ham    =     61.0 : 1.0\n",
        "                    font = True             spam : ham    =     61.0 : 1.0\n",
        "                  valign = True             spam : ham    =     60.3 : 1.0\n",
        "                      br = True             spam : ham    =     59.6 : 1.0\n",
        "                 verdana = True             spam : ham    =     57.7 : 1.0\n",
        "                    nbsp = True             spam : ham    =     57.4 : 1.0\n",
        "                   color = True             spam : ham    =     54.4 : 1.0\n",
        "                  ff0000 = True             spam : ham    =     53.0 : 1.0\n",
        "                  ffffff = True             spam : ham    =     50.6 : 1.0\n",
        "                  border = True             spam : ham    =     49.6 : 1.0\n",
        "None\n"
       ]
      }
     ],
     "prompt_number": 10
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "If we try just using the text of the messages, without the HTML tags and information, we lose a tiny bit of accuracy in identifying spam but do much better with the hard ham."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "check_classifier(word_indicator, stopwords = sw, strip_html = True)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "Test Spam accuracy: 96.64%\n",
        "Test Ham accuracy: 98.64%"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "Test Hard Ham accuracy: 56.05%"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "Most Informative Features\n",
        "                    dear = True             spam : ham    =     41.7 : 1.0"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "                     aug = True              ham : spam   =     38.3 : 1.0\n",
        "              guaranteed = True             spam : ham    =     35.0 : 1.0\n",
        "              assistance = True             spam : ham    =     29.7 : 1.0\n",
        "                  groups = True              ham : spam   =     27.9 : 1.0\n",
        "                mailings = True             spam : ham    =     25.0 : 1.0\n",
        "               sincerely = True             spam : ham    =     23.0 : 1.0\n",
        "                    fill = True             spam : ham    =     23.0 : 1.0\n",
        "                mortgage = True             spam : ham    =     21.7 : 1.0\n",
        "                     sir = True             spam : ham    =     21.0 : 1.0\n",
        "                 sponsor = True              ham : spam   =     20.3 : 1.0\n",
        "                 article = True              ham : spam   =     20.3 : 1.0\n",
        "                  assist = True             spam : ham    =     19.0 : 1.0\n",
        "                  income = True             spam : ham    =     18.6 : 1.0\n",
        "                     tue = True              ham : spam   =     18.3 : 1.0\n",
        "                   mails = True             spam : ham    =     18.3 : 1.0\n",
        "                     iso = True             spam : ham    =     17.7 : 1.0\n",
        "                   admin = True              ham : spam   =     17.7 : 1.0\n",
        "                  monday = True              ham : spam   =     17.7 : 1.0\n",
        "                    earn = True             spam : ham    =     17.0 : 1.0\n",
        "None\n"
       ]
      }
     ],
     "prompt_number": 11
    }
   ],
   "metadata": {}
  }
 ]
}